{
 "cells": [
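  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3D Segmentation with UNet\n",
    "\n",
    "Train a 3D UNet (MONAI + PyTorch Ignite) on synthetic NIfTI volumes, with checkpointing, periodic validation (mean Dice) and TensorBoard logging."
   ]
  },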
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0+untagged.106.gf2e8580.dirty\n",
      "Python version: 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31)  [GCC 7.3.0]\n",
      "Numpy version: 1.17.4\n",
      "Pytorch version: 1.4.0a0+a5b4d78\n",
      "Ignite version: 0.3.0\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "from glob import glob\n",
    "import logging\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "import torch\n",
    "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
    "from ignite.handlers import ModelCheckpoint\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import monai\n",
    "from monai.data import NiftiDataset, create_test_image_3d\n",
    "from monai.transforms import Compose, AddChannel, ScaleIntensity, Resize, ToTensor, RandSpatialCrop\n",
    "from monai.handlers import \\\n",
    "    StatsHandler, TensorBoardStatsHandler, TensorBoardImageHandler, MeanDice, stopping_fn_from_metric\n",
    "from monai.networks.utils import predict_segmentation\n",
    "\n",
    "# Seed the global PyTorch and NumPy RNGs so weight initialisation (and any NumPy-based\n",
    "# randomness) repeats across Restart & Run All. NOTE: MONAI's random transforms may keep their\n",
    "# own RNG state - TODO confirm whether they need separate seeding for full determinism.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# print library versions, then route INFO-level log records to stdout so they show in the notebook\n",
    "monai.config.print_config()\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up demo data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write 50 synthetic 128x128x128 volumes and their segmentation masks as NIfTI files\n",
    "# (identity affine) into a fresh temporary directory, named im<i>.nii.gz / seg<i>.nii.gz\n",
    "tempdir = tempfile.mkdtemp()\n",
    "\n",
    "for i in range(50):\n",
    "    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)\n",
    "    # image first, then its mask - same file pattern per pair, different prefix\n",
    "    for name_pattern, volume in (('im%i.nii.gz', im), ('seg%i.nii.gz', seg)):\n",
    "        nifti = nib.Nifti1Image(volume, np.eye(4))\n",
    "        nib.save(nifti, os.path.join(tempdir, name_pattern % i))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up transforms and dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 1, 96, 96, 96]) torch.Size([10, 1, 96, 96, 96])\n"
     ]
    }
   ],
   "source": [
    "# Collect the NIfTI files written above; both lists are sorted the same way, so index k\n",
    "# of `images` and index k of `segs` refer to the same image/mask pair\n",
    "images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))\n",
    "segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))\n",
    "\n",
    "# Training-time transforms: image and mask each get a channel axis and a fixed-size random\n",
    "# 96^3 crop; only the image is intensity-scaled\n",
    "roi_size = (96, 96, 96)\n",
    "imtrans = Compose([ScaleIntensity(), AddChannel(), RandSpatialCrop(roi_size, random_size=False), ToTensor()])\n",
    "segtrans = Compose([AddChannel(), RandSpatialCrop(roi_size, random_size=False), ToTensor()])\n",
    "\n",
    "# Sanity check: pull one batch through the dataset/loader and show the tensor shapes\n",
    "ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)\n",
    "loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())\n",
    "im, seg = monai.utils.misc.first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Model, Loss, Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# learning rate for the Adam optimizer created at the end of this cell\n",
    "lr = 1e-3\n",
    "\n",
    "# 3D UNet with one input channel and one output channel; feature widths 16..256 with\n",
    "# four stride-2 steps and 2 residual units per block\n",
    "net = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=1,\n",
    "                               channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2)\n",
    "\n",
    "# Dice loss computed on the sigmoid of the network output, optimised with Adam\n",
    "loss = monai.losses.DiceLoss(do_sigmoid=True)\n",
    "opt = torch.optim.Adam(params=net.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create supervised_trainer using ignite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create trainer. Use the first GPU when one is available, otherwise fall back to the CPU so the\n",
    "# notebook still runs end to end (the DataLoaders' pin_memory flags use the same check).\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# non_blocking=False: batches are copied to `device` synchronously\n",
    "trainer = create_supervised_trainer(net, opt, loss, device, non_blocking=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up event handlers for checkpointing and logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "### optional section for checkpoint and tensorboard logging\n",
    "# At the end of every training epoch, save the network weights and the optimizer state to './runs/'\n",
    "# (up to n_saved=10 checkpoints); require_empty=False allows the directory to already contain files.\n",
    "checkpoint_handler = ModelCheckpoint(dirname='./runs/', filename_prefix='net', n_saved=10, require_empty=False)\n",
    "trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, to_save={'net': net, 'opt': opt})\n",
    "\n",
    "# Two sinks for the per-iteration loss (and per-epoch metrics, none of which are set on the trainer):\n",
    "# StatsHandler named 'trainer' prints them, TensorBoardStatsHandler plots them.\n",
    "# Either can take an output_transform if engine.state.output is not a plain loss value.\n",
    "train_stats_handler = StatsHandler(name='trainer')\n",
    "train_tensorboard_stats_handler = TensorBoardStatsHandler()\n",
    "for stats_handler in (train_stats_handler, train_tensorboard_stats_handler):\n",
    "    stats_handler.attach(trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Add validation every N epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "### optional section for model validation during training\n",
    "validation_every_n_epochs = 1\n",
    "# Set parameters for validation\n",
    "metric_name = 'Mean_Dice'\n",
    "# add evaluation metric to the evaluator engine\n",
    "val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}\n",
    "\n",
    "# ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,\n",
    "# user can add output_transform to return other values\n",
    "evaluator = create_supervised_evaluator(net, val_metrics, device, non_blocking=True)\n",
    "\n",
    "# create a validation data loader: a Resize replaces the training-time random crop, and the\n",
    "# last 20 of the 50 generated files are used, disjoint from the first 20 used for training\n",
    "val_imtrans = Compose([\n",
    "    ScaleIntensity(),\n",
    "    AddChannel(),\n",
    "    Resize((96, 96, 96)),\n",
    "    ToTensor()\n",
    "])\n",
    "val_segtrans = Compose([\n",
    "    AddChannel(),\n",
    "    Resize((96, 96, 96)),\n",
    "    ToTensor()\n",
    "])\n",
    "val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)\n",
    "val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())\n",
    "\n",
    "\n",
    "# run a full pass over the validation set after every `validation_every_n_epochs` training epochs\n",
    "@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n",
    "def run_validation(engine):\n",
    "    evaluator.run(val_loader)\n",
    "\n",
    "\n",
    "# Add stats event handler to print validation stats via evaluator\n",
    "val_stats_handler = StatsHandler(\n",
    "    name='evaluator',\n",
    "    output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output\n",
    "    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer\n",
    "val_stats_handler.attach(evaluator)\n",
    "\n",
    "# add handler to record metrics to TensorBoard at every validation epoch\n",
    "val_tensorboard_stats_handler = TensorBoardStatsHandler(\n",
    "    output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output\n",
    "    global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer\n",
    "val_tensorboard_stats_handler.attach(evaluator)\n",
    "\n",
    "# add handler to draw the first image and the corresponding label and model output in the last batch\n",
    "# here we draw the 3D output as GIF format along Depth axis, at every validation epoch\n",
    "val_tensorboard_image_handler = TensorBoardImageHandler(\n",
    "    batch_transform=lambda batch: (batch[0], batch[1]),\n",
    "    output_transform=lambda output: predict_segmentation(output[0]),\n",
    "    global_iter_transform=lambda x: trainer.state.epoch\n",
    ")\n",
    "# trailing ';' stops Jupyter from echoing the returned RemovableEventHandle repr\n",
    "evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a training data loader over the first 20 image/mask pairs, reusing the random-crop\n",
    "# training transforms (imtrans / segtrans); logging is already configured in the first cell\n",
    "train_ds = NiftiDataset(images[:20], segs[:20], transform=imtrans, seg_transform=segtrans)\n",
    "train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())\n",
    "\n",
    "# train; checkpointing, stats logging and validation run through the handlers attached above\n",
    "train_epochs = 5\n",
    "state = trainer.run(train_loader, max_epochs=train_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing TensorBoard logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "\n",
    "# The TensorBoard handlers above were created without an explicit log directory,\n",
    "# so their event files end up in TensorBoard's default './runs'\n",
    "log_dir = './runs'\n",
    "%tensorboard --logdir {log_dir}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# Remove the synthetic NIfTI data. Unlike the shell `rm -rf {tempdir}`, rmtree is portable and\n",
    "# is not broken by spaces in the path; ignore_errors=True mirrors `rm -f` on a missing directory.\n",
    "shutil.rmtree(tempdir, ignore_errors=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
